import numpy as np
# import pykalman
import matplotlib.pyplot as plt
import matplotlib        as mpl
import scipy.stats as stats
from sklearn.metrics import r2_score

plt.rcParams["font.family"] = "Arial"
import matplotlib.transforms as mtransforms

from matplotlib.backends.backend_pdf import PdfPages


def multipage(filename, figs=None, dpi=500):
    """Save several figures into a single multi-page PDF.

    Parameters
    ----------
    filename : str or path-like
        Output PDF path.
    figs : iterable of matplotlib Figure, optional
        Figures to write, one per page, in order.  Defaults to every open
        pyplot figure, ordered by figure number.
    dpi : int
        Resolution forwarded to ``Figure.savefig`` (in a vector PDF this only
        matters for rasterized content).  Previously accepted but ignored.
    """
    if figs is None:
        figs = [plt.figure(n) for n in plt.get_fignums()]
    # The context manager closes the PDF even if one savefig call fails.
    with PdfPages(filename) as pp:
        for fig in figs:
            fig.savefig(pp, format='pdf', dpi=dpi)


# Plot colours shared by the figures below, as RGB tuples scaled to 0..1.
colorekf = tuple(channel/255. for channel in (26, 44, 105))
colorstate = tuple(channel/255. for channel in (62, 207, 117))
coloruncert = tuple(channel/255. for channel in (52, 235, 235))
colorgrey = (0.5,)*3

# Colormap for the projection density contours: four opaque-white rows, then
# the Greens ramp with its last 16 entries removed and every RGBA channel
# squared, then a single opaque-black row.
nfrac = 16  # not referenced by the code shown
upper = mpl.cm.Greens(np.arange(256))
mycmap0 = np.vstack([np.ones((4, 4)), np.power(upper[:-16, :], 2), [0, 0, 0, 1]])
mycmap = mpl.colors.ListedColormap(mycmap0, name='myColorMap', N=len(mycmap0))



# ---- Problem size and unit conversions ------------------------------------
n_iters = 173  # number of yearly rows processed (1850..2022 when sdate == 1850)
sz = (n_iters,2) # size of array
sz2d=(n_iters,2,2)

# Converts the model's ocean-heat units (H) to ZJ. Presumably Earth area
# (5.1006447e14 m^2) x seconds per year (3.154e7 s) x ocean fraction (0.71),
# scaled to ZJ -- TODO(review): confirm.
zJ_from_W = 5.1006447*3.154*0.71
zJtomm=0.121  # ZJ -> thermosteric sea level in mm (divided by 10 for the cm axes)
pcolor='deeppink'
colorrts='goldenrod'
nbins=24
sdate=1850  # calendar year of row 0

# ---- Historical input table ------------------------------------------------
# Columns as used in this file: 0 year, 1 GMST anomaly, 2 CO2 (log10 taken),
# 3 optical depth (scaled by 0.001), 4 GMST 1-sigma, 5 ocean heat (ZJ),
# 6 ocean-heat 1-sigma (ZJ), 7 anthropogenic-cloud term (+1 offset), 8 tsi.
# Pass the path so genfromtxt opens and closes the file itself; the previous
# open(...) handle was never closed.
data = np.genfromtxt("toyKFmodelData8.csv", dtype=float, delimiter=',')
dates=data[:,0]
dates[0]=sdate  # dates is a view, so this also writes data[0, 0]
#temps=data[:, 1]+287
lCo2=np.log10(data[:,2])
opt_depth=data[:,3]*0.001 #*0.053) #/1000?0.000100025
anthro_clouds=(data[:,7]+1)
R_tvar=np.square(data[:,4]) #still in temperature units
Roc_tvar=np.square(data[:,6])/zJ_from_W/zJ_from_W  # variance converted from ZJ^2 to H units^2
tsi=data[:,8]
ocean_heat_measured = data[:,5]
#critexp = np.ones(ln(transmd))

#data2 = np.genfromtxt(open("HadCRUT5.global.annual.csv", "rb"),dtype=float, delimiter=',')
# ---- Absolute GMST ---------------------------------------------------------
# Column 1 holds GMST anomalies. Shift them to absolute kelvin by pinning their
# mean over rows (1960-sdate)..(1989-sdate) to JonesOffset degrees C.
# NOTE(review): that slice covers the years 1960-1989. If the intended
# baseline is the 1961-1990 normal, the slice would be
# (1961-sdate):(1991-sdate) -- confirm.
temps=data[:, 1]
JonesOffset=13.85
offset= -np.mean(temps[(1960-sdate):(1990-sdate)]) +JonesOffset + 273.15 #Jones2013 13.7 to 14 (Jones1999)
print(offset)
temps=temps+offset
#pindavg= np.mean(temps[(1850-sdate):(1930-sdate)])
# pindavg and Tconst (below) are not referenced by the code shown.
pindavg=286.7
# Divisor applied to the net flux in compute_update to give the one-step GMST
# change, i.e. an effective surface heat capacity.
heatCp=17

# Stefan-Boltzmann constant (W m^-2 K^-4).
sig=5.6704e-8

# Shortwave factors. a_refl and g_refl multiply the incoming flux (see inbndf
# and rad1850); sw_in is the insolation used when tsi[k] is not selected.
a_refl= 0.834 #constant clearsky albedo
g_refl=0.909
sw_in=340.2

Tconst=286.64 #86.7



# Reference temperature about which the shortwave feedbacks are linearized (x - T02).
T02=287.5 #in 2002
# Equilibrium 1850 GMST: mean of the first 25 rows (1850-1874 for yearly rows).
Teq1850=np.mean(temps[0:25])
print("Teq1850",Teq1850)
# Fractional shortwave feedback slopes per kelvin. They enter the incoming flux
# as (1 + dfaA*(x-T02)) * (1 + dfaS*(x-T02)) in compute_update.
dfaS=0.42/(sw_in*a_refl*g_refl)
dfaA=0.35/(sw_in*a_refl*g_refl)
# Outgoing flux scales as x**(4 - powp1) rather than x**4.
powp1 = 1.3*4/3.22 #1.3
B1B0 = 12.74/sig/np.power( T02, 4-powp1 ) #15.45 #19.45
inbndf= a_refl*g_refl*9.068 #sw_in* 137.49
print("inbndf = " + str(inbndf*tsi[0]))

print("B1B0 = " + str(B1B0))

# Incoming flux at Teq1850. This approximately matches `inbound` in
# compute_update with tsik = sw_in and optical depth 0 (9.068/9.7279 ~ 0.932).
rad1850 = (sw_in*0.9318*a_refl*(1+dfaA*(Teq1850-T02))+anthro_clouds[0])*g_refl*(1+dfaS*(Teq1850-T02))
print("rad1850 = " + str(rad1850))
# B1 is chosen so that outgoing = sig*(B1 - B1B0*lCo2)*x**(4-powp1) balances
# rad1850 at Teq1850 when lCo2 == 2.444.
# NOTE(review): this line divides by 5.670e-8 while outbndf multiplies by
# sig = 5.6704e-8, so the balance is off by sig/5.670e-8 (~1.00007) -- confirm.
B1 = (rad1850 / 5.670e-8 / np.power( Teq1850, 4-powp1 ))+B1B0 *2.444 #594
print("B1 = " + str(B1))
B0 = B1B0/B1
print("B0 = " + str(B0))
outbndf= sig*B1
print("outbndf = " + str(outbndf))
# Two-layer ocean: shallow (Cs) and deep (Cd) heat capacities proportional to
# layer depth in metres.
shaldepth=86
deepdepth = 1141
Cs= 136.5*shaldepth/1000 #<17 like maybe 14
Cd= 136.5*deepdepth/1000
# Surface/deep exchange coefficient; multiplies the anomaly gap in compute_update.
gad= 0.67
#epd = 1   #1.3
oc1850= 3.5 + 273.15 #absolute temperature of deep ocean
# Observed ocean heat converted from ZJ into the model's H units.
oc_meas = ocean_heat_measured/zJ_from_W #/deepdepth/1.55 + oc1850 #convert into a deep ocean temp



# intial parameters

# Observation matrix. Both state components (GMST, ocean heat) are observed
# directly; ekf_run applies this identity implicitly (y = z - xhatminus).
H=np.eye(2) #emissions matrix

#covariance of the process noise
# Q is estimated from the data: the covariance of the GMST and ocean-heat
# residuals about a centred running mean, divided by 30.
# NOTE(review): ekf_run adds Q per step and a further Q*30 inside S, so this
# residual covariance is shared between process and measurement noise -- confirm.

N = 30
moving_aves=np.empty(len(temps))
moving_aves[:] = np.nan
ocean_aves=np.empty(len(temps))
ocean_aves[:] = np.nan
cN=int(np.ceil(N/2));
fN=int(np.floor(N/2));

# Centred running means; entries without a window stay NaN.
# NOTE(review): temps[i-cN : i+fN+1] holds cN+fN+1 = N+1 (31) samples, not N.
# Only the final iteration (i = len-fN) is cut to N samples by the slice end.
# If a true N-sample mean is intended, lasta should be i+fN.
for i in range(cN,(len(temps)-fN+1)):
    lasta=i+fN+1;
    firsta=i-cN;
    moving_aves[i] = np.mean(temps[firsta:lasta]);
    ocean_aves[i] = np.mean(oc_meas[firsta:lasta]);

# Residuals about the running means. The covariance uses indices cN..len-fN-1,
# which skips the NaN edges and the single truncated-window entry.
resid_temp = (temps - moving_aves);
resid_ocean = (oc_meas - ocean_aves);
covM= np.cov(np.array([resid_temp[cN:-fN],resid_ocean[cN:-fN]]));
print(covM);

Q=(covM/30);

def compute_slope(x2,ki, optdn=-1, lCo2n=-1,anthro_cloud=-101):
    """Return the 2x2 Jacobian of one EBM step, evaluated at ``x2``.

    The filter advances the state as
    ``state_{k+1} = state_k + compute_update(state_k, k, ...)``, so the
    transition Jacobian is ``F = I + d(compute_update)/d[x, H]``.  This function
    must stay the analytic derivative of ``compute_update``.  Its arguments and
    their sentinel/default rules are identical to ``compute_update``'s.

    Parameters
    ----------
    x2 : sequence of 2 floats
        State ``[x, H]``: GMST in K and ocean heat in the model's H units.
    ki : int-like
        Row of the forcing arrays used for any argument left at its sentinel.
    optdn : float
        Optical depth; a negative value selects ``opt_depth[k]``.
    lCo2n : float
        log10 CO2; a non-positive value selects ``lCo2[k]`` *and* insolation
        ``tsi[k]`` (otherwise the constant ``sw_in`` is the insolation).
    anthro_cloud : float
        Anthropogenic-cloud term; a value below -100 selects ``anthro_clouds[k]``.

    Returns
    -------
    numpy.ndarray, shape (2, 2)
        ``[[dT'/dx, dT'/dH], [dH'/dx, dH'/dH]]`` for the next-step state (T', H').
    """
    k = int(ki)
    tsik = sw_in
    if anthro_cloud < -100:
        anthro_cloud = anthro_clouds[k]
    if optdn < 0:
        optdn = opt_depth[k]
    if lCo2n <= 0:
        lCo2n = lCo2[k]
        tsik = tsi[k]
    x, _ = x2  # H enters compute_update only linearly, via the coupling term

    dev = x - T02
    cloud_term = anthro_cloud/sw_in/0.9318/a_refl
    # Product rule on
    #   inbound = tsik*inbndf*(1 + dfaA*dev + cloud_term)*(1 + dfaS*dev)/(optdn + 9.7279)
    # NOTE: earlier revisions used dfaS in place of dfaA in both spots below.
    inboundd = tsik*inbndf/(optdn+9.7279)*((dfaA+dfaS) + 2*(dfaA*dfaS)*dev + dfaS*cloud_term)
    outgoingd = outbndf*(4-powp1)*np.power(x, 3-powp1)*(1-B0*lCo2n)

    # Coupling term of compute_update: E = x - Teq1850 - oc + oc1850 with
    # oc = (H - Cs*(x - Teq1850))/Cd + oc1850, hence:
    dE_dx = 1 + Cs/Cd   # earlier revisions had 1 - Cs/Cd
    dE_dH = -1/Cd

    # T' = x + Tinc,  Tinc = (inbound - outgoing - gad*E)/heatCp
    dTinc_dx = (inboundd - outgoingd - gad*dE_dx)/heatCp
    dTinc_dH = -gad*dE_dH/heatCp
    # H' = H + Tinc*Cs + gad*E.  H' depends on x only through Tinc and E, so
    # there is no standalone +Cs term (earlier revisions added one).
    dHinc_dx = dTinc_dx*Cs + gad*dE_dx
    dHinc_dH = dTinc_dH*Cs + gad*dE_dH

    return np.array([[1 + dTinc_dx, dTinc_dH],
                     [dHinc_dx, 1 + dHinc_dH]])

def compute_update(x2,ki, optdn=-1, lCo2n=-1,anthro_cloud=-101):
    """Return the one-step increment ``[dT, dH]`` of the two-layer EBM.

    By convention the prior state is *not* included: callers add the result to
    the previous state.

    Parameters
    ----------
    x2 : sequence of 2 floats
        State ``[x, H]``: GMST ``x`` in K and ocean heat ``H`` in the model's
        units (``oc_meas`` = ZJ / ``zJ_from_W``).
    ki : int-like
        Row of the forcing arrays (``year - sdate``) used for any argument left
        at its sentinel value.
    optdn : float
        Optical depth; a negative value selects ``opt_depth[k]``.
    lCo2n : float
        log10 CO2; a non-positive value selects ``lCo2[k]`` and insolation
        ``tsi[k]`` (otherwise the constant ``sw_in`` is the insolation).
    anthro_cloud : float
        Anthropogenic-cloud term; a value below -100 selects ``anthro_clouds[k]``.

    Returns
    -------
    numpy.ndarray, shape (2,)
        ``[dT, dH]`` with ``dH = dT*Cs + gad*(x - Teq1850 - oc + oc1850)``.

    Raises
    ------
    ValueError
        If ``x < 280`` K.  It subclasses the ``Exception`` raised previously, so
        existing handlers still catch it.  The reported year is ``ki + sdate``,
        which is only meaningful when ``ki`` indexes the historical record.
    """
    k = int(ki)
    x, ohc = x2  # named ohc so the module-level observation matrix H is not shadowed
    if x < 280:
        raise ValueError("Too Cold "+str(k+sdate))

    tsik = sw_in
    if anthro_cloud < -100:
        anthro_cloud = anthro_clouds[k]
    if optdn < 0:
        optdn = opt_depth[k]
    if lCo2n <= 0:
        lCo2n = lCo2[k]
        tsik = tsi[k]

    inbound = tsik*inbndf*(1+dfaA*(x-T02)+anthro_cloud/sw_in/0.9318/a_refl)*(1+dfaS*(x-T02))/(optdn+9.7279)
    outgoing = outbndf*np.power(x, 4-powp1)*(1-B0*lCo2n)
    # Deep-ocean temperature implied by the ocean heat once the shallow-layer
    # share Cs*(x - Teq1850) is removed.
    oc = (ohc - Cs*(x-Teq1850))/Cd + oc1850
    # Surface -> deep exchange, proportional to the gap between the surface and
    # deep-ocean temperature anomalies relative to 1850.
    coupling = gad*(x-Teq1850 - oc + oc1850)
    Tchange = (inbound - outgoing - coupling)/heatCp
    return np.array([Tchange, Tchange*Cs + coupling])


# allocate space for arrays

# ---- Storage filled in place by ekf_run -------------------------------------
# Per-year state vectors, shape (n_iters, 2): [GMST, ocean-heat state].
#   xhat: a posteriori, xhatminus: a priori (predicted),
#   xhathat: RTS-smoothed, xblind: model run with no measurement updates.
xhat, xhatminus, xhathat, xblind = (np.zeros(sz) for _ in range(4))

# Per-year 2x2 matrices, shape (n_iters, 2, 2).
#   P / Pminus: posterior / prior covariance, Phat: smoothed covariance,
#   F: transition Jacobians, K / Khat: filter / smoother gains,
#   S / Shat: innovation covariance for the filter / smoother.
P, Pminus, Phat, F, K, Khat, S, Shat = (np.zeros(sz2d) for _ in range(8))

# Innovation diagnostics: y raw innovations, qqy standardized innovations.
lml, y, qqy = (np.zeros(sz) for _ in range(3))
lsml = 0
qqyh, qqyk = [], []

# Grid for the innovation density plot and the standard-normal reference curve.
xnorm = np.linspace(-5.5, 5.5, 200)
ynorm = stats.norm.pdf(xnorm)
qqyker = np.zeros((2, xnorm.size))  # accumulated densities, one row per state
qqykerall = []                      # per-step [gmst, ohca] densities

def ekf_future(startyr,Pstart,xhatstart,endyr,case,caseA,noVolcs=False,case_start=2015):
    """Propagate the EBM-KF state and covariance forward with no measurements.

    Parameters
    ----------
    startyr, endyr : int
        Year of output row 0 (seeded from ``xhatstart``/``Pstart``) and the
        exclusive end year; ``endyr - startyr`` rows are produced.
    Pstart : array, shape (2, 2)
        Covariance at ``startyr``.
    xhatstart : array, shape (2,)
        State ``[GMST, ocean-heat state]`` at ``startyr``.
    case, caseA : sequences
        Scenario forcings indexed by ``year - case_start``.  ``case`` is passed
        as the log10-CO2 argument and ``caseA + 1`` as the anthropogenic-cloud
        argument (the same +1 offset applied to ``anthro_clouds``).
    noVolcs : bool
        If True, use a constant optical depth of 3.75e-3 instead of sampled
        eruptions.
    case_start : int
        Calendar year of ``case[0]``/``caseA[0]`` (previously hard-coded 2015).

    Returns
    -------
    (xhatf, Phatf) : arrays of shape (nfiter, 2) and (nfiter, 2, 2).

    Raises
    ------
    ValueError
        If ``endyr <= startyr``.

    Note: forcing row ``0`` is passed as ``ki``, so the year in a "Too Cold"
    error from ``compute_update`` is not the projection year.
    """
    nfiter = endyr - startyr
    if nfiter < 1:
        raise ValueError("ekf_future needs endyr > startyr, got %r..%r" % (startyr, endyr))
    Phatf = np.zeros((nfiter, 2, 2))
    Phatf[0] = Pstart
    xhatf = np.zeros((nfiter, 2))
    xhatf[0] = xhatstart

    if noVolcs:
        volcs = np.ones(nfiter)*3.75*0.001
    else:
        # Local import so this also works when the module is imported rather
        # than run as __main__ (the global VolcanoFit is only bound there).
        import VolcanoFit1 as VolcanoFit
        volcs = VolcanoFit.genEruptions(nfiter+30)*0.001

    for k in range(1, nfiter):
        row = startyr - case_start + k - 1
        forcing = (volcs[k-1], case[row], caseA[row] + 1)
        xhatf[k] = xhatf[k-1] + compute_update(xhatf[k-1], 0, *forcing)
        # Linearize about the same previous state used for the mean update.
        Fnow = compute_slope(xhatf[k-1], 0, *forcing)
        Phatf[k] = Fnow @ Phatf[k-1] @ Fnow.T + Q
    return (xhatf, Phatf)
        

def _ekf_forward(z, n_iter):
    """Forward EKF pass: fill the module-level filter arrays for rows 0..n_iter-1."""
    # Initial guesses: 1850 equilibrium GMST and the first ocean-heat observation.
    xhat[0] = [Teq1850, oc_meas[0]]
    xblind[0] = xhat[0]
    P[0, :, :] = 1.0   # every entry set to 1 ...
    P[0, 1, 1] = 20    # ... then a wider prior on the ocean-heat state
    Pminus[0, :, :] = P[0, :, :]
    # F[0] is not read by the smoother (it uses F[1..n_iter-1]); it is filled
    # so that every row of F is defined.
    F[0] = compute_slope(xhat[0, :], 0)

    for k in range(1, n_iter):
        # Time update: linearize about the previous posterior with year-k forcing.
        F[k] = compute_slope(xhat[k-1], k)
        xhatminus[k] = xhat[k-1] + compute_update(xhat[k-1], k)
        xblind[k] = xblind[k-1] + compute_update(xblind[k-1], k)  # no correction
        Pminus[k] = F[k] @ P[k-1] @ F[k].T + Q

        # Measurement update; the observation matrix is the identity.
        R_k = np.diag([R_tvar[k], Roc_tvar[k]])
        S[k] = Pminus[k] + Q*30 + R_k
        K[k] = Pminus[k] @ np.linalg.inv(S[k])
        y[k] = z[k] - xhatminus[k]
        xhat[k] = xhatminus[k] + K[k] @ y[k]
        P[k] = (np.eye(2) - K[k]) @ Pminus[k]

        # Standardized innovations, each smeared by its measurement std
        # (in innovation-std units) for the density plot.
        stdevS = np.sqrt(np.abs(np.diag(S[k])))
        qqy[k] = y[k]/stdevS
        gmstpdf = stats.norm.pdf(xnorm, loc=qqy[k, 0], scale=data[k, 4]/stdevS[0])
        ohcapdf = stats.norm.pdf(xnorm, loc=qqy[k, 1], scale=data[k, 6]/zJ_from_W/stdevS[1])
        qqykerall.append([gmstpdf, ohcapdf])
        if k > 1:
            # Average over the n_iter-2 innovations k = 2..n_iter-1.  The old
            # divisor (n_iters-1) did not match the number of terms.
            qqyker[0] += gmstpdf/(n_iter-2)
            qqyker[1] += ohcapdf/(n_iter-2)


def _rts_smooth(n_iter):
    """Rauch-Tung-Striebel backward pass over the arrays filled by _ekf_forward."""
    xhathat[n_iter-1] = xhat[n_iter-1]
    Phat[n_iter-1] = P[n_iter-1]
    xhathat[0] = xhat[0]
    Phat[0] = P[0]

    for k in range(n_iter-2, -1, -1):
        try:
            Khat[k] = P[k] @ F[k+1].T @ np.linalg.inv(Pminus[k+1])
        except np.linalg.LinAlgError:
            # Singular predicted covariance: reuse the gain of the later step.
            Khat[k] = Khat[k+1]
        # The step k -> k+1 used forcing k+1, so its prediction is exactly the
        # forward-pass xhatminus[k+1] / Pminus[k+1].  The previous version
        # recomputed it with forcing k and subtracted Pminus[k].
        xhathat[k] = xhat[k] + Khat[k] @ (xhathat[k+1] - xhatminus[k+1])
        Phat[k] = P[k] + Khat[k] @ (Phat[k+1] - Pminus[k+1]) @ Khat[k].T
        Shat[k] = Phat[k] + Q*30 + np.diag([R_tvar[k], Roc_tvar[k]])


def ekf_run(z,n_iter,retPs=False):
    """Run the extended Kalman filter over ``z``, then the RTS smoother.

    Parameters
    ----------
    z : array, shape (>= n_iter, 2)
        Yearly observations: column 0 GMST (K), column 1 ocean heat in the
        model's H units (as in ``observ``).
    n_iter : int
        Number of rows to process (at most ``n_iters``, the preallocated size).
    retPs : bool
        If True, also return the a posteriori covariances.

    Returns
    -------
    ``xhat[0:n_iter]`` (a view of the module-level array), plus
    ``P[0:n_iter]`` when ``retPs`` is True.

    Side effects: fills the module-level arrays xhat, xhatminus, xblind, P,
    Pminus, F, K, S, y, qqy, xhathat, Phat, Khat and Shat.  It also appends to
    ``qqykerall`` and adds into ``qqyker``, so repeated calls keep accumulating
    there.
    """
    _ekf_forward(z, n_iter)
    _rts_smooth(n_iter)
    if retPs:
        return xhat[0:n_iter], P[0:n_iter]
    return xhat[0:n_iter]


# ---- Run the filter/smoother on the historical record -----------------------
observ = np.array([temps, oc_meas]).T  # rows = years, columns = [GMST, H]
this_xhat = ekf_run(observ, n_iters)

# GMST (state column 0): estimates and 1-sigma spreads.
xh1s = this_xhat[:, 0]     # a posteriori
xh0s = xhatminus[:, 0]     # a priori
xhh1s = xhathat[:, 0]      # RTS smoothed
stdS = np.sqrt(np.abs(S[:, 0, 0]))
stdP = np.sqrt(np.abs(P[:, 0, 0]))
stdPh = np.sqrt(np.abs(Phat[:, 0, 0]))
stdSh = np.sqrt(np.abs(Shat[:, 0, 0]))

# Ocean heat (state column 1).
xh1d = this_xhat[:, 1]
xh0d = xhatminus[:, 1]
stdSd = np.sqrt(np.abs(S[:, 1, 1]))
stdPd = np.sqrt(np.abs(P[:, 1, 1]))

# Final-year state and covariance; ekf_future starts the projections from these.
xlastretain = this_xhat[-1, :]
# NOTE(review): elementwise abs also makes a negative GMST/OHCA covariance
# positive -- confirm that is intended.
Plastretain = np.abs(P[-1])

# Measurement variances averaged over R_tvar[-11:-1] (the ten entries before the
# last one), plus the same 30*Q inflation ekf_run adds to S.
Rtvara = np.mean(R_tvar[-11:-1]) + 30*Q[0, 0]
Roctvara = np.mean(Roc_tvar[-11:-1]) + 30*Q[1, 1]

def plot_boilerplate(ax=None):
    """Apply the shared GMST time-series styling to ``ax``.

    Sets the 286.1-288.3 K y-range, 25-year x ticks over 1850-2025, axis labels
    and inward ticks on all four sides.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to style.  Defaults to the *current* axes at call time.  The old
        default ``plt.gca()`` was evaluated once at import: it created a stray
        figure, and every no-argument call styled that stale axes.
    """
    if ax is None:
        ax = plt.gca()
    ax.set_ylim(286.1,288.3)
    ax.set_xticks(np.arange(1850,2025+1,25))
    ax.set_xlim(1850,2025)
    ax.set_xlabel('Year')
    ax.set_ylabel('Temperature (K)')
    ax.tick_params( direction = 'in',bottom=True, top=True, left=True, right=True )

if __name__ == "__main__":
    # ---- Goodness-of-fit summaries ----------------------------------------
    # Blind model (no measurement updates) against the raw observations.
    r2 = r2_score(observ[:,0], xblind[:,0])
    print('r2 score for GMST to blind is', r2)
    r2 = r2_score(observ[:,1], xblind[:,1])
    print('r2 score for OHCA to blind is', r2)

    # Filter against the running means, on indices where those are defined.
    r2 = r2_score(moving_aves[15:-16], xh1s[15:-16])
    print('r2 score for 30-mean GMST to EBM-KF is', r2)
    r2 = r2_score(ocean_aves[15:-16], xh1d[15:-16])
    print('r2 score for 30-mean OHCA to EBM-KF is', r2)

    # ---- Figure: filter as a weighted average of model and data ------------
    plt.rcParams['figure.figsize'] = (7, 6)
    fig, (ax1,ax4)= plt.subplots(2, 1, figsize=(7,10), gridspec_kw={ "hspace": 0.3})
    plot_boilerplate(ax1)
    plot_boilerplate(ax4)
    ax1.plot(dates,temps,'o',label='noisy HadCRUT5 GMST measurements $Y_{n}$',markersize=2,color=colorgrey)
    ax1.fill_between(dates, temps-2*data[:,4], temps+2*data[:,4],label="associated 95% uncertainty", color="lightgrey")
    ax1.plot(dates,this_xhat[:,0],'-',label=r'a posteriori GMST EBM-KF state estimate $\^{T }_{n}}$', color=colorekf)
    ax1.plot(dates,xblind[:,0],'--',label=r'blind model GMST prediction $\~{T}_{n+1} = F( \~{T}_{n}, \~{H}_{n} ; [eCO_2]_n ,AOD_n,AC_{n})$',color='darkgoldenrod')
    ax1.legend(loc="upper left", fontsize="9")
    ax1.set_title('Global Mean Surface Temperature State')
    ax1.set_xlabel('Year')
    ax1.set_ylabel('Temperature (K)')
    ax1.set_ylim(286.1,288.3)
    ax1.set_yticks(np.arange(286.2,288.3,0.2))
    ax1.set_xlim(1850,2025)
    ax1b = ax1.twinx()
    mn, mx = ax1.get_ylim()
    ax1b.set_ylim(mn-273.15, mx-273.15)
    ax1b.set_yticks(np.arange(13,15.2,0.2))
    ax1b.set_ylabel('Temperature (°C)')

    ax4.plot(dates,ocean_heat_measured,'o',label=r'noisy Zanna (2019) measurements $\Psi_{n}$',markersize=2,color=colorgrey)
    ax4.fill_between(dates, ocean_heat_measured-2*data[:,6], ocean_heat_measured+2*data[:,6],label="associated 95% uncertainty", color="lightgrey")
    ax4.plot(dates,this_xhat[:,1]*zJ_from_W,'-',label=r'a posteriori OHCA EBM-KF state estimate $\^{H }_{n}}$', color=colorekf)
    ax4.plot(dates,xblind[:,1]*zJ_from_W,'--',label=r'blind model OHCA prediction $\~{H}_{n+1} = F( \~{T}_{n}, \~{H}_{n} ; [eCO_2]_n ,AOD_n,AC_{n})$',color='darkgoldenrod')
    ax4.legend(loc="upper left", fontsize="9")
    ax4.set_title('Ocean Total Heat State')
    ax4.set_xlabel('Year')
    ax4.set_ylabel('Heat (ZJ)')
    ax4.set_ylim([-150,600])
    ax4b = ax4.twinx()
    mn, mx = ax4.get_ylim()
    ax4b.set_yticks(np.arange(-1,8,1))
    ax4b.set_ylim(mn*zJtomm/10, mx*zJtomm/10)
    ax4b.set_ylabel('Thermosteric Sea Level Rise (cm)')

    axlabels=['a)','b)','c)','d)']
    for ax, label in zip((ax1, ax4), axlabels):
        # Panel label at a fixed physical offset left of and above the axes corner.
        trans = mtransforms.ScaledTranslation(-40/72, 15/72, fig.dpi_scale_trans)
        ax.text(0.0, 1.0, label, transform=ax.transAxes + trans, fontsize='large', va='bottom')

    fig.savefig("weighted_average.pdf",format="pdf")
    fig.savefig("weighted_average.png", dpi=400,format="png")

    # ---- Figure: deep-ocean potential temperature ---------------------------
    # Uses the same relation as compute_update: oc = (H - Cs*(x-Teq1850))/Cd + oc1850
    # (this figure previously used heatCp in place of Cs).
    fig = plt.figure(figsize=(8.25,6))
    plt.plot(dates,(oc_meas - Cs*(temps-Teq1850))/Cd+oc1850,'o',markersize=2,label='inferred from both Zanna and HADCRUT5 measurements',color=colorgrey)
    plt.plot(dates,(this_xhat[:,1]- Cs*(this_xhat[:,0]-Teq1850))/Cd+oc1850,'-', label=r'a posteriori ocean temp. EBM-KF state estimate $\^{\theta }_{n}}$',color=colorekf)
    plt.plot(dates,(xblind[:,1]- Cs*(xblind[:,0]-Teq1850))/Cd+oc1850,'--',label=r'blind model prediction $\~{\theta }_{n}}$',color='darkgoldenrod')
    plt.legend()
    ax=plt.gca()
    ax.set_ylabel('Temperature (K)')
    mn, mx = ax.get_ylim()
    ax1b = ax.twinx()
    ax1b.set_ylim(mn-273.15, mx-273.15)
    ax1b.set_ylabel('Temperature (°C)')
    ax.set_title("Deep Ocean Potential Temperature $\\theta$")
    fig.savefig("DOPT.pdf",format="pdf")
    fig.savefig("DOPT.png", dpi=400,format="png")

    # ---- Figure: GMST state with uncertainty and innovation diagnostics -----
    fig = plt.figure(figsize=(8.25,6))
    ax_dict = fig.subplot_mosaic(
    """
    ab
    ac
    """,
    gridspec_kw={
        "width_ratios": [7, 1.25], "wspace": 0.4, "hspace": 0.3, "left":0.1, "right":0.96
    },)
    plt.sca(ax_dict["a"])
    plot_boilerplate(ax_dict["a"])
    plt.yticks(np.arange(286.2,288.3,0.2))
    ax_dict["a"].set_title('Estimated Climate State with Extended Kalman Filter')
    ax_dict["a"].plot(dates,this_xhat[:,0],'-',label=r'a posteriori EBM-KF state estimate $\^{T }_{n}$', color=colorekf)
    ax_dict["a"].fill_between(dates, xh0s-2*stdS, xh0s+2*stdS,label=r"95% CI from $S_n$ of measurements around pred. GMST $\^{T }_{n|n-1}$", color=coloruncert)
    ax_dict["a"].fill_between(dates, xh1s-2*stdP, xh1s+2*stdP,label=r"95% CI from $P_n$ of GMST state $\^{T }_{n}$", color=colorstate)
    ax_dict["a"].plot(dates,temps,'o',label='HadCRUT5 GMST measurements $Y_{n}$',markersize=2,color=colorgrey)
    ax_dict["a"].legend(fontsize="9.5")
    ax1b = ax_dict["a"].twinx()
    mn, mx = ax_dict["a"].get_ylim()
    ax1b.set_ylim(mn-273.15, mx-273.15)
    ax1b.set_yticks(np.arange(13,15.2,0.2))
    ax1b.set_ylabel('Temperature (°C)')
    ax1 = ax_dict["b"]
    ax2 = ax_dict["c"]
    # Row 0 of qqy is never filled by ekf_run (its loop starts at k=1), so skip it.
    stats.probplot(qqy[1:,0], dist="norm", plot=ax2)
    ax1.set_title("Innovations")
    ax2.get_lines()[0].set_markerfacecolor(colorekf)
    ax2.get_lines()[0].set_markeredgewidth(0)
    ax2.get_lines()[1].set_color(pcolor)
    ax2.set_xlabel("Theoretical Quantiles",labelpad=0)
    ax2.set_title("")
    ax1.plot(xnorm,qqyker[0,:], color=colorekf)
    ax1.plot(xnorm,ynorm, color=pcolor)
    ax1.set_xlabel("Predict. Std Dev from $S_n$",labelpad=-3)
    ax1.set_ylabel("Probability Density")
    axlabels=['a)','b)','c)']
    for ax, label in zip((ax_dict["a"],ax1,ax2), axlabels):
        trans = mtransforms.ScaledTranslation(-40/72, 1/72, fig.dpi_scale_trans)
        ax.text(0.0, 1.0, label, transform=ax.transAxes + trans,
            fontsize='large', va='bottom')
    plt.savefig("mainGMST.pdf", format = "pdf")
    plt.savefig("mainGMST.png", dpi=400,format="png")

    # ---- Figure: OHCA state with uncertainty and innovation diagnostics -----
    fig = plt.figure(figsize=(8.25,6))
    ax_dict = fig.subplot_mosaic(
    """
    ab
    ac
    """,
    gridspec_kw={
        "width_ratios": [7, 1.25], "wspace": 0.4, "hspace": 0.3, "left":0.1, "right":0.96
    },)
    plt.sca(ax_dict["a"])
    plot_boilerplate(ax_dict["a"])
    ax_dict["a"].set_title('Estimated Ocean Total Heat State with EBM-KF')
    ax_dict["a"].set_ylabel('Heat (ZJ)')
    # The a posteriori estimate is xh1d; this previously plotted the a priori xh0d.
    ax_dict["a"].plot(dates,xh1d*zJ_from_W,'-',label=r'a posteriori OHCA EBM-KF state estimate $\^{H }_{n}$', color=colorekf)
    ax_dict["a"].fill_between(dates, (xh0d-2*stdSd)*zJ_from_W, (xh0d+2*stdSd)*zJ_from_W,label=r"95% CI from $S_n$ of measurements around pred. OHCA $\^{H }_{n|n-1}$", color=coloruncert)
    ax_dict["a"].fill_between(dates, (xh1d-2*stdPd)*zJ_from_W, (xh1d+2*stdPd)*zJ_from_W,label=r"95% CI from $P_n$ of OHCA state $\^{H }_{n}$", color=colorstate)
    ax_dict["a"].plot(dates,ocean_heat_measured,'o',label=r'noisy Zanna (2019) measurements $\Psi_{n}$',markersize=2,color=colorgrey)
    ax_dict["a"].legend(fontsize="9.5")
    ax_dict["a"].set_ylim([-150,600])
    ax4b = ax_dict["a"].twinx()
    mn, mx = ax_dict["a"].get_ylim()
    ax4b.set_ylim(mn*zJtomm/10, mx*zJtomm/10)
    ax4b.set_yticks(np.arange(-1,8,1))
    ax4b.set_ylabel('Thermosteric Sea Level Rise (cm)')
    ax1 = ax_dict["b"]
    ax2 = ax_dict["c"]
    stats.probplot(qqy[1:,1], dist="norm", plot=ax2)
    ax1.set_title("Innovations")
    ax2.get_lines()[0].set_markerfacecolor(colorekf)
    ax2.get_lines()[0].set_markeredgewidth(0)
    ax2.get_lines()[1].set_color(pcolor)
    ax2.set_xlabel("Theoretical Quantiles",labelpad=0)
    ax2.set_title("")
    ax1.plot(xnorm,qqyker[1,:], color=colorekf)
    ax1.plot(xnorm,ynorm, color=pcolor)
    ax1.set_xlabel("Predict. Std Dev from $S_n$",labelpad=-3)
    ax1.set_ylabel("Probability Density")
    axlabels=['a)','b)','c)']
    for ax, label in zip((ax_dict["a"],ax1,ax2), axlabels):
        trans = mtransforms.ScaledTranslation(-40/72, 1/72, fig.dpi_scale_trans)
        ax.text(0.0, 1.0, label, transform=ax.transAxes + trans,
            fontsize='large', va='bottom')

    plt.savefig("mainOHCA.pdf", format = "pdf")
    plt.savefig("mainOHCA.png", dpi=400,format="png")

    # ---- Figure: effect of the RTS smoother ---------------------------------
    plt.rcParams['figure.figsize'] = (7, 6)
    plt.figure()
    plt.title('Slight Changes with RT Smoother')
    plt.plot(dates,xhathat[:,0],'-',label=r'RTS smoothed GMST estimate $ \^{\^{T }}_{n}}$',color=colorrts)
    plt.plot(dates,xhat[:,0],'-',label=r'a posteriori GMST EBM-KF state estimate $\^{T }_{n}}$',color=colorekf)
    plt.fill_between(dates[4:], xhh1s[4:]-2*stdPh[4:], xhh1s[4:]+2*stdPh[4:],label=r"95% CI from $\hat{\hat {P_n}}$ of RTS GMST state $\^{\^{T }}_{n}$", color='goldenrod', alpha=0.5)
    plt.fill_between(dates[4:], xh1s[4:]-2*stdP[4:], xh1s[4:]+2*stdP[4:],label=r"95% CI from $P_n$ of EBM-KF GMST state $\^{T }_{n}$", color=colorstate, alpha=0.5)
    plt.plot(dates,temps,'o',label='GMST HadCRUT5 measurements $Y_{n}$',markersize=2,color=colorgrey)
    ax=plt.gca()
    ax.set_ylabel('Temperature (K)')
    ax.set_xlabel('Year')
    ax.set_yticks(np.arange(286.4,288.2,0.2))
    # Pass the axes explicitly: a no-argument call relied on a default bound at import time.
    plot_boilerplate(ax)
    handles, labels = ax.get_legend_handles_labels()
    order = [1,0,4,3,2]
    plt.legend([handles[idx] for idx in order],[labels[idx] for idx in order])
    axb = ax.twinx()
    mn, mx = ax.get_ylim()
    axb.set_yticks(np.arange(13.2,15,0.2))
    axb.set_ylim(mn- 273.15, mx- 273.15)
    axb.set_ylabel('Temperature (°C)')
    plt.savefig("RTS_smoother_changes.pdf", format = "pdf")
    plt.savefig("RTS_smoother_changes.png", dpi=400,format="png")

    # ---- Figures: future projections for two SSP scenarios ------------------
    custlevelsl = -np.flip(np.arange(-2.25,3.001,0.25))*np.log(4)-np.log(25)+np.log(16)
    SSPnames=[126,434,245,370,585]
    volcsplot=True
    gridspec = dict(wspace=0.4, width_ratios=[3.75,3.75, 0.45])
    fig, axs= plt.subplots(1, 3, figsize=(8.25,6),gridspec_kw=gridspec)
    plt.suptitle("Projected Surface Climate State")
    fig2, axs2= plt.subplots(1, 3, figsize=(8.25,6),gridspec_kw=gridspec)
    plt.suptitle("Projected Ocean Heat Content State")
    end_temps=[[],[]]
    if volcsplot:
        # Binds the module-level name VolcanoFit used for eruption sampling.
        import VolcanoFit1 as VolcanoFit
        # Scenario table, read once (it was re-opened, and never closed, on
        # every loop pass). Column 1+rcp is passed as log10 CO2 and column 6+rcp
        # as caseA to ekf_future.
        data3 = np.genfromtxt("KF6projectionSSP.csv", dtype=float, delimiter=',')
        for rcpa in [0,1]:
            ax=axs[rcpa]
            ax2=axs2[rcpa]
            rcp=rcpa*3  # column offset and SSPnames index: 0 -> SSP126, 3 -> SSP370
            ax.plot(dates,xh1s,'-',label=r'past EBM-KF GMST state estimate $\^{T }_{n}}$',color=colorekf)
            ax.fill_between(dates, xh1s-2*stdP, xh1s+2*stdP,label=r"95% CI from $P_n$ of EBM-KF GMST state $\^{T }_{n}}$", color=colorstate, alpha=0.5)
            ax.plot(dates,temps,'o',label='GMST HadCRUT measurements $Y_{n}$',markersize=2,color=colorgrey)

            ax2.plot(dates,xh1d*zJ_from_W,'-',label=r'past EBM-KF OHCA state estimate $\^{ H }_{n}}$', color=colorekf)
            ax2.fill_between(dates, xh1d*zJ_from_W-2*stdPd*zJ_from_W, xh1d*zJ_from_W+2*stdPd*zJ_from_W,label=r"95% CI from $P_n$ of EBM-KF OHCA state $\^{ H }_{n}}$", color=colorstate, alpha=0.5)
            ax2.plot(dates,ocean_heat_measured,'o',label=r'noisy Zanna (2019) measurements $\Psi_{n}$',markersize=2,color=colorgrey)

            startf=2022 #not 2023!
            endf=2100
            inum=4  # interpolation sub-steps per year in the density tables
            pdftable0=np.zeros((5001,(endf-startf)*inum))
            pdftable1=np.zeros((4501,(endf-startf)*inum))
            tppdftable0=np.zeros((5001,(endf-startf)*inum))
            tppdftable1=np.zeros((4501,(endf-startf)*inum))
            heights0=np.linspace(286.5,291.5,5001)   # GMST grid, 0.001 K spacing
            heights1=np.linspace(-1000,8000,4501)    # OHCA grid, 2 ZJ spacing
            nsamps=12 #5min per 1000, 6000 looks great
            for i in range(0,nsamps):
                (fxhat,fP0)=ekf_future(startf,Plastretain,xlastretain,endf+1,np.log10(data3[:,1+rcp]),data3[:,6+rcp])
                stdfP=np.sqrt(np.abs(fP0))
                stdfS00=np.sqrt(np.abs(fP0[:,0,0]+Rtvara))
                stdfS11=np.sqrt(np.abs(fP0[:,1,1]+Roctvara))
                end_temps[rcpa].append(fxhat[-1,:])
                if i == 0:
                    ax.plot(range(startf,endf+1),fxhat[:,0],'-',color='k',lw=0.2,label="Samples from volc distn of EBM-KF GMST")
                    ax2.plot(range(startf,endf+1),fxhat[:,1]*zJ_from_W,'-',color='k',lw=0.2,label="Samples from volc distn of EBM-KF OHCA")
                    # Same projection with a constant optical depth instead of sampled eruptions.
                    (fxhatnv,fP0nv)=ekf_future(startf,Plastretain,xlastretain,endf+1,np.log10(data3[:,1+rcp]),data3[:,6+rcp], noVolcs=True)
                    ax.plot(range(startf,endf+1),fxhatnv[:,0],'-',color='b',lw=0.6,label="Future EBM-KF GMST state, const. volc")
                    ax2.plot(range(startf,endf+1),fxhatnv[:,1]*zJ_from_W,'-',color='b',lw=0.6,label="Future EBM-KF OHCA state, const. volc")
                elif i < 10:
                    ax.plot(range(startf,endf+1),fxhat[:,0],'-',color='k',lw=0.2)
                    ax2.plot(range(startf,endf+1),fxhat[:,1]*zJ_from_W,'-',color='k',lw=0.2)
                # Accumulate Gaussian densities on sub-yearly columns, linearly
                # interpolating mean and spread between years t and t+1.
                for t in range(endf-startf):
                    for interp in range(inum):
                        w = interp/inum
                        col = t*inum+interp
                        iloc0 = fxhat[t,0]*(1-w) + fxhat[t+1,0]*w
                        iloc1 = (fxhat[t,1]*(1-w) + fxhat[t+1,1]*w)*zJ_from_W
                        iscale0 = stdfP[t,0,0]*(1-w) + stdfP[t+1,0,0]*w
                        iscale1 = (stdfP[t,1,1]*(1-w) + stdfP[t+1,1,1]*w)*zJ_from_W
                        pdftable0[:,col] += stats.norm.pdf(heights0,loc=iloc0,scale=iscale0)
                        pdftable1[:,col] += stats.norm.pdf(heights1,loc=iloc1,scale=iscale1)
                        tpscale0 = stdfS00[t]*(1-w) + stdfS00[t+1]*w
                        tpscale1 = (stdfS11[t]*(1-w) + stdfS11[t+1]*w)*zJ_from_W
                        tppdftable0[:,col] += stats.norm.pdf(heights0,loc=iloc0,scale=tpscale0)
                        tppdftable1[:,col] += stats.norm.pdf(heights1,loc=iloc1,scale=tpscale1)

            if nsamps < 6000:
                # Too few samples for smooth densities: use the stored
                # 6000-sample cumulative tables instead.
                nsamps=6000
                with np.load("SSP"+"{:.0f}".format(SSPnames[rcp])+".npz") as fdata:
                    cdftable0=fdata['cgmst']
                    cdftable1=fdata['cohca']
                pdftable0=cdftable0.copy()
                pdftable0[1:,:] -= cdftable0[:-1,:]
                pdftable0=pdftable0*nsamps
                pdftable1=cdftable1.copy()
                pdftable1[1:,:] -= cdftable1[:-1,:]
                pdftable1=pdftable1*nsamps
            else:
                cdftable0=np.cumsum(pdftable0/nsamps,axis=0)
                cdftable1=np.cumsum(pdftable1/nsamps,axis=0)
                tpcdftable0=np.cumsum(tppdftable0/nsamps,axis=0)
                tpcdftable1=np.cumsum(tppdftable1/nsamps,axis=0)
                outfile=("SSP"+"{:.0f}".format(SSPnames[rcp]))
                np.savez_compressed(outfile,gmstheights=heights0,ohcaheights=heights1, cgmst=cdftable0,cohca=cdftable1, tpgmst=tpcdftable0,tpohca=tpcdftable1)

            years_f = np.arange(startf,endf,1/inum)
            # Densities per K scaled to "per 0.4 K" (x2/5) to match the colour-bar label.
            CS=ax.contourf(years_f,heights0,np.log(pdftable0/nsamps*2/5+1e-20),levels=custlevelsl,cmap=mycmap,extend='both')
            # cdftable0 is a cumulative sum of densities on a 0.001 K grid,
            # i.e. 1000 x the CDF, so 25 / 975 mark the 2.5% / 97.5% levels.
            l25cdftable0=np.absolute(cdftable0-25)
            l975cdftable0=np.absolute(cdftable0-975)
            ax.plot(years_f,heights0[np.argmin(l25cdftable0,axis=0)],'-',color=pcolor)
            ax.plot(years_f,heights0[np.argmin(l975cdftable0,axis=0)],'-',color=pcolor,label="2.5-97.5% CI of volc EBM-KF GMST")
            CS2=ax2.contourf(years_f,heights1,np.log(pdftable1/nsamps*250+1e-20),levels=custlevelsl,cmap=mycmap, extend='both')
            # heights1 has 2 ZJ spacing, so cdftable1 is 0.5 x the CDF.
            l25cdftable1=np.absolute(cdftable1-0.5*0.025)
            l975cdftable1=np.absolute(cdftable1-0.5*0.975)
            ax2.plot(years_f,heights1[np.argmin(l25cdftable1,axis=0)],'-',color=pcolor)
            ax2.plot(years_f,heights1[np.argmin(l975cdftable1,axis=0)],'-',color=pcolor,label="2.5-97.5% CI of volc EBM-KF OHCA")
            print("Maximum Densities")
            print(np.max(pdftable0/nsamps/10))
            print(np.max(pdftable1/nsamps/10))

            ax.set_xlabel('Year')
            ax.tick_params( direction = 'in',bottom=True, top=True, left=True, right=True )
            ax.set_xlim([1990,2100])
            ax.set_ylim(286.9,291.5)
            ax.set_yticks(np.arange(287,291.5,0.4))
            axb = ax.twinx()
            mn, mx = ax.get_ylim()
            axb.set_yticks(np.arange(14,18.8,0.4))
            axb.set_ylim(mn- 273.15, mx- 273.15)
            ax2.set_xlabel('Year')
            ax2.tick_params( direction = 'in',bottom=True, top=True, left=True, right=True )
            ax2.set_xlim([1990,2100])
            ax2.set_yticks(np.arange(0,3250,250))
            ax2.set_ylim([250,3000])

            ax4b = ax2.twinx()
            mn, mx = ax2.get_ylim()
            ax4b.set_yticks(np.arange(0,40,2))
            ax4b.set_ylim(mn*zJtomm/10, mx*zJtomm/10)
            if rcpa == 0:
                ax.set_ylabel('Temperature (K)')
                ax2.set_ylabel('Heat (ZJ)')
                ax.legend(loc="best",prop={'size': 7})
                ax2.legend(loc="best",prop={'size': 7})
            else:
                axb.set_ylabel('Temperature (°C)')
                ax4b.set_ylabel('Thermosteric Sea Level Rise (cm)')

            ssp_label = "{:.0f}".format(SSPnames[rcp])
            ax.set_title(axlabels[rcpa]+"        SSP"+ssp_label+" Future Projection        ")
            ax2.set_title(axlabels[rcpa]+"        SSP"+ssp_label+" Future Projection        ")
        cbar = fig.colorbar(CS,cax=axs[2],ticks=custlevelsl[0::2])
        cbar2 = fig2.colorbar(CS2,cax=axs2[2],ticks=custlevelsl[0::2])
        cbar.ax.set_ylabel('Probability Density per 0.4K')
        cbar2.ax.set_ylabel('Probability Density per 250ZJ')
        cbar.ax.set_yticklabels(np.round(np.exp(custlevelsl[0::2]),8))
        cbar2.ax.set_yticklabels(np.round(np.exp(custlevelsl[0::2]),8))

    fig.savefig("futGMSTs.pdf", format = "pdf")
    fig.savefig("futGMSTs.png", dpi=400,format="png")
    fig2.savefig("futOHCAs.pdf", format = "pdf")
    fig2.savefig("futOHCAs.png", dpi=400,format="png")

    plt.show()



